{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Regarding Warnings\n",
    "* Warnings shown in this notebook are caused by package-version mismatches between the local and server environments.\n",
    "* Readers can ignore them, or install the required versions with `pip install [package]==[version]`."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Import Necessary Packages"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import configparser\n",
    "import getpass\n",
    "import os\n",
    "\n",
    "import pandas as pd\n",
    "import snowflake.snowpark.functions as F\n",
    "from snowflake.ml.model import deploy_platforms\n",
    "from snowflake.ml.model.models import huggingface_pipeline\n",
    "from snowflake.ml.registry import model_registry\n",
    "from snowflake.snowpark import Session\n",
    "from transformers import pipeline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Uncomment and run these lines if the packages imported above are not installed yet.\n",
    "#%pip install transformers\n",
    "#%pip install snowflake-ml-python"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Connect To Snowflake"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the Snowflake connection parameters from the local credentials file.\n",
    "snowflake_credentials_file = '../snowflake_creds.config'\n",
    "config = configparser.ConfigParser()\n",
    "\n",
    "# ConfigParser.read() skips files it cannot open and returns the list of files it parsed,\n",
    "# so check the result and fail with a clear message instead of an opaque KeyError below.\n",
    "if not config.read(snowflake_credentials_file):\n",
    "    raise FileNotFoundError(f\"Snowflake credentials file not found: {snowflake_credentials_file}\")\n",
    "\n",
    "# Every key/value pair of the [default] section is passed to the session builder as-is.\n",
    "connection_parameters = dict(config['default'])\n",
    "\n",
    "session = Session.builder.configs(connection_parameters).create()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load HuggingFace llama2 Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\Admin\\AppData\\Local\\Temp\\ipykernel_9556\\2781243291.py:2: DeprecationWarning: \n",
      "The `snowflake.ml.registry.model_registry.ModelRegistry` has been deprecated starting from version 1.2.0.\n",
      "It will stay in the Private Preview phase. For future implementations, kindly utilize `snowflake.ml.registry.Registry`,\n",
      "except when specifically required. The old model registry will be removed once all its primary functionalities are\n",
      "fully integrated into the new registry.\n",
      "        \n",
      "  registry = model_registry.ModelRegistry(session=session, database_name=\"SNOWPARK\", schema_name=\"TUTORIAL\", create_if_not_exists=True)\n",
      "create_model_registry() is in private preview since 0.2.0. Do not use it in production. \n",
      "WARNING:absl:The database SNOWPARK already exists. Skipping creation.\n",
      "WARNING:absl:The schema SNOWPARK.TUTORIAL already exists. Skipping creation.\n"
     ]
    }
   ],
   "source": [
    "# Read the Hugging Face access token from the HF_AUTH_TOKEN environment variable, or prompt\n",
    "# for it, so the secret is never stored in the notebook source.\n",
    "HF_AUTH_TOKEN = os.environ.get(\"HF_AUTH_TOKEN\") or getpass.getpass(\"Hugging Face access token: \")\n",
    "\n",
    "# Model registry backed by SNOWPARK.TUTORIAL; create_if_not_exists=True is passed so the\n",
    "# registry is set up on first use.\n",
    "registry = model_registry.ModelRegistry(\n",
    "    session=session,\n",
    "    database_name=\"SNOWPARK\",\n",
    "    schema_name=\"TUTORIAL\",\n",
    "    create_if_not_exists=True,\n",
    ")\n",
    "\n",
    "# Text-generation pipeline definition for the Llama 2 7B chat checkpoint.\n",
    "llama_model = huggingface_pipeline.HuggingFacePipelineModel(\n",
    "    task=\"text-generation\",\n",
    "    model=\"meta-llama/Llama-2-7b-chat-hf\",\n",
    "    token=HF_AUTH_TOKEN,\n",
    "    return_full_text=False,\n",
    "    max_new_tokens=100,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Register llama2 Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
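    "# Optional cleanup: uncomment to delete a previously registered model version\n",
    "# before logging it again in the next cell.\n",
    "# MODEL_NAME = \"LLAMA2_MODEL_7b_CHAT\"\n",
    "# MODEL_VERSION = \"1\"\n",
    "# registry.delete_model(\n",
    "#     model_name=MODEL_NAME,\n",
    "#     model_version=MODEL_VERSION,\n",
    "# )"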
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:snowflake.snowpark:ModelRegistry.log_model() is in private preview since 0.2.0. Do not use it in production. \n"
     ]
    }
   ],
   "source": [
    "# Name and version under which the pipeline is logged in the registry.\n",
    "MODEL_NAME, MODEL_VERSION = \"LLAMA2_MODEL_7b_CHAT\", \"1\"\n",
    "\n",
    "# NOTE: `llama_model` is rebound here to the object returned by log_model(); the deploy\n",
    "# cell further down calls .deploy() on that returned object.\n",
    "llama_model = registry.log_model(model=llama_model, model_name=MODEL_NAME, model_version=MODEL_VERSION)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>CREATION_CONTEXT</th>\n",
       "      <th>CREATION_ENVIRONMENT_SPEC</th>\n",
       "      <th>CREATION_ROLE</th>\n",
       "      <th>CREATION_TIME</th>\n",
       "      <th>ID</th>\n",
       "      <th>INPUT_SPEC</th>\n",
       "      <th>NAME</th>\n",
       "      <th>OUTPUT_SPEC</th>\n",
       "      <th>RUNTIME_ENVIRONMENT_SPEC</th>\n",
       "      <th>TYPE</th>\n",
       "      <th>URI</th>\n",
       "      <th>VERSION</th>\n",
       "      <th>ARTIFACT_IDS</th>\n",
       "      <th>DESCRIPTION</th>\n",
       "      <th>METRICS</th>\n",
       "      <th>TAGS</th>\n",
       "      <th>REGISTRATION_TIMESTAMP</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>None</td>\n",
       "      <td>{\\n  \"python\": \"3.8.16\"\\n}</td>\n",
       "      <td>\"ACCOUNTADMIN\"</td>\n",
       "      <td>2024-04-12 09:56:39.829000-07:00</td>\n",
       "      <td>980e6a80f8ed11eea87b34f39a51dc3f</td>\n",
       "      <td>None</td>\n",
       "      <td>LLAMA2_MODEL_7b_CHAT</td>\n",
       "      <td>None</td>\n",
       "      <td>None</td>\n",
       "      <td>huggingface_pipeline</td>\n",
       "      <td>sfc://SNOWPARK.TUTORIAL.SNOWML_MODEL_980E6A80F...</td>\n",
       "      <td>1</td>\n",
       "      <td>None</td>\n",
       "      <td>None</td>\n",
       "      <td>None</td>\n",
       "      <td>None</td>\n",
       "      <td>2024-04-12 09:56:42.126000-07:00</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "  CREATION_CONTEXT   CREATION_ENVIRONMENT_SPEC   CREATION_ROLE  \\\n",
       "0             None  {\\n  \"python\": \"3.8.16\"\\n}  \"ACCOUNTADMIN\"   \n",
       "\n",
       "                     CREATION_TIME                                ID  \\\n",
       "0 2024-04-12 09:56:39.829000-07:00  980e6a80f8ed11eea87b34f39a51dc3f   \n",
       "\n",
       "  INPUT_SPEC                  NAME OUTPUT_SPEC RUNTIME_ENVIRONMENT_SPEC  \\\n",
       "0       None  LLAMA2_MODEL_7b_CHAT        None                     None   \n",
       "\n",
       "                   TYPE                                                URI  \\\n",
       "0  huggingface_pipeline  sfc://SNOWPARK.TUTORIAL.SNOWML_MODEL_980E6A80F...   \n",
       "\n",
       "  VERSION ARTIFACT_IDS DESCRIPTION METRICS  TAGS  \\\n",
       "0       1         None        None    None  None   \n",
       "\n",
       "            REGISTRATION_TIMESTAMP  \n",
       "0 2024-04-12 09:56:42.126000-07:00  "
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Fetch the registry listing and render it as a pandas DataFrame.\n",
    "registered_models = registry.list_models()\n",
    "registered_models.to_pandas()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Deploy llama2 Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Settings handed to the deployment: compute pool, GPU count and external access integrations.\n",
    "deployment_options = {\n",
    "    \"compute_pool\": \"snowpark_cs_compute_pool\",\n",
    "    \"num_gpus\": 1,\n",
    "    \"external_access_integrations\": [\"ALLOW_ALL_ACCESS_INTEGRATION\"],\n",
    "}\n",
    "\n",
    "# Deploy the registered model to Snowpark Container Services under the name \"llama_predict\".\n",
    "llama_model.deploy(\n",
    "    deployment_name=\"llama_predict\",\n",
    "    platform=deploy_platforms.TargetPlatform.SNOWPARK_CONTAINER_SERVICES,\n",
    "    options=deployment_options,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load News Category JSON Into a Snowflake Table"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "NEWS_DATA_TABLE_NAME = \"NEWS_DATASET\"\n",
    "\n",
    "# The dataset is read with lines=True (one JSON record per line); convert_dtypes()\n",
    "# then switches the columns to the best-fitting nullable pandas dtypes.\n",
    "news_dataset = pd.read_json(\"../datasets/News_Category_Dataset_v3.json\", lines=True).convert_dtypes()\n",
    "\n",
    "# Upload the frame to Snowflake; overwrite=True is passed so re-running this cell\n",
    "# does not append duplicate rows.\n",
    "# NOTE(review): the table preview below shows DATE as large integers rather than\n",
    "# timestamps - confirm whether the column should be converted before upload.\n",
    "news_dataset_sp_df = session.write_pandas(\n",
    "    news_dataset,\n",
    "    NEWS_DATA_TABLE_NAME,\n",
    "    auto_create_table=True,\n",
    "    quote_identifiers=False,\n",
    "    overwrite=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------\n",
      "|\"LINK\"                                              |\"HEADLINE\"                                          |\"CATEGORY\"      |\"SHORT_DESCRIPTION\"                                 |\"AUTHORS\"             |\"DATE\"               |\n",
      "------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------\n",
      "|https://www.huffpost.com/entry/covid-boosters-u...  |Over 4 Million Americans Roll Up Sleeves For Om...  |U.S. NEWS       |Health experts said it is too early to predict ...  |Carla K. Johnson, AP  |1663891200000000000  |\n",
      "|https://www.huffpost.com/entry/american-airline...  |American Airlines Flyer Charged, Banned For Lif...  |U.S. NEWS       |He was subdued by passengers and crew when he f...  |Mary Papenfuss        |1663891200000000000  |\n",
      "|https://www.huffpost.com/entry/funniest-tweets-...  |23 Of The Funniest Tweets About Cats And Dogs T...  |COMEDY          |\"Until you have a dog you don't understand what...  |Elyse Wanshel         |1663891200000000000  |\n",
      "|https://www.huffpost.com/entry/funniest-parenti...  |The Funniest Tweets From Parents This Week (Sep...  |PARENTING       |\"Accidentally put grown-up toothpaste on my tod...  |Caroline Bologna      |1663891200000000000  |\n",
      "|https://www.huffpost.com/entry/amy-cooper-loses...  |Woman Who Called Cops On Black Bird-Watcher Los...  |U.S. NEWS       |Amy Cooper accused investment firm Franklin Tem...  |Nina Golgowski        |1663804800000000000  |\n",
      "|https://www.huffpost.com/entry/belk-worker-foun...  |Cleaner Was Dead In Belk Bathroom For 4 Days Be...  |U.S. NEWS       |The 63-year-old woman was seen working at the S...  |                      |1663804800000000000  |\n",
      "|https://www.huffpost.com/entry/reporter-gets-ad...  |Reporter Gets Adorable Surprise From Her Boyfri...  |U.S. NEWS       |\"Who's that behind you?\" an anchor for New York...  |Elyse Wanshel         |1663804800000000000  |\n",
      "|https://www.huffpost.com/entry/puerto-rico-wate...  |Puerto Ricans Desperate For Water After Hurrica...  |WORLD NEWS      |More than half a million people remained withou...  |DÁNICA COTO, AP       |1663804800000000000  |\n",
      "|https://www.huffpost.com/entry/mija-documentary...  |How A New Documentary Captures The Complexity O...  |CULTURE & ARTS  |In \"Mija,\" director Isabel Castro combined musi...  |Marina Fang           |1663804800000000000  |\n",
      "|https://www.huffpost.com/entry/biden-un-russian...  |Biden At UN To Call Russian War An Affront To B...  |WORLD NEWS      |White House officials say the crux of the presi...  |Aamer Madhani, AP     |1663718400000000000  |\n",
      "------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------\n",
      "\n"
     ]
    }
   ],
   "source": [
    "news_dataset_sp_df.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Prompting & Prediction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>input</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>[INST] &lt;&gt;\\nYour output will be parsed by a com...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>[INST] &lt;&gt;\\nYour output will be parsed by a com...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>[INST] &lt;&gt;\\nYour output will be parsed by a com...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>[INST] &lt;&gt;\\nYour output will be parsed by a com...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>[INST] &lt;&gt;\\nYour output will be parsed by a com...</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                               input\n",
       "0  [INST] <>\\nYour output will be parsed by a com...\n",
       "1  [INST] <>\\nYour output will be parsed by a com...\n",
       "2  [INST] <>\\nYour output will be parsed by a com...\n",
       "3  [INST] <>\\nYour output will be parsed by a com...\n",
       "4  [INST] <>\\nYour output will be parsed by a com..."
      ]
     },
     "execution_count": 31,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Llama 2 chat prompt: the system instructions sit between <<SYS>> and <</SYS>> inside the\n",
    "# [INST] block, and the user text (the news description) follows before the closing [/INST].\n",
    "prompt_prefix = \"\"\"[INST] <<SYS>>\n",
    "Your output will be parsed by a computer program as a JSON object. Please respond ONLY with valid json that conforms to this JSON schema: {\"properties\": {\"category\": {\"type\": \"string\",\"description\": \"The category that the news should belong to.\"},\"keywords\": {\"type\": \"array\",\"description\": \"The keywords that are mentioned in the news.\",\"items\": [{\"type\": \"string\"}]},\"importance\": {\"type\": \"number\",\"description\": \"An integer from 1 to 10 to show if the news is important. The higher the number, the more important the news is.\"}},\"required\": [\"category\",\"keywords\",\"importance\"]}\n",
    " \n",
    "As an example, input \"Residents ordered to evacuate amid threat of growing wildfire in Washington state, medical facilities sheltering in place\" results in the json: {\"category\": \"Natural Disasters\",\"keywords\": [\"evacuate\", \"wildfire\", \"Washington state\", \"medical facilities\"],\"importance\": 8}\n",
    "<</SYS>>\n",
    "\"\"\"\n",
    "prompt_suffix = \"[/INST]\"\n",
    "\n",
    "# Build one prompt per row: prefix + SHORT_DESCRIPTION + suffix, joined by single spaces,\n",
    "# and keep only the resulting (case-sensitive) \"input\" column.\n",
    "df_inputs = news_dataset_sp_df.with_column(\n",
    "    '\"input\"',\n",
    "    F.concat_ws(F.lit(\" \"), F.lit(prompt_prefix), F.col('SHORT_DESCRIPTION'), F.lit(prompt_suffix)),\n",
    ").select('\"input\"')\n",
    "\n",
    "# Preview only: fetch five rows instead of pulling the whole table to the client.\n",
    "df_inputs.limit(5).to_pandas()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Request predictions from the deployed model. The names used here are the ones defined\n",
    "# earlier in this notebook: `llama_model` is the object returned by registry.log_model(),\n",
    "# \"llama_predict\" is the deployment_name used in the deploy cell, and `df_inputs` is the\n",
    "# prompt DataFrame built in the previous cell.\n",
    "res = llama_model.predict(\n",
    "    deployment_name=\"llama_predict\",\n",
    "    data=df_inputs,\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "def_gui_3.8_env",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
